# Script used to generate the C-alpha displacement analysis shown in Figure 4.

import os
import numpy as np
import matplotlib.pyplot as plt
from Bio import PDB
from Bio.PDB import PDBParser, Superimposer, PDBIO
from collections import defaultdict

def _extract_protein_chains(structures, pdb_files):
    """Return (chain, pdb_file, model_id) for every chain holding a standard amino acid."""
    records = []
    for structure, pdb_file in zip(structures, pdb_files):
        for model in structure:
            for chain in model:
                if any(PDB.Polypeptide.is_aa(residue, standard=True) for residue in chain):
                    records.append((chain, pdb_file, model.id))
    return records


def _amino_acid_residues(chain):
    """Return the amino-acid residues of *chain* in file order, dropping waters and ligands."""
    return [residue for residue in chain if PDB.Polypeptide.is_aa(residue)]


def _truncated_copy(chain_id, residues):
    """Return a new Chain named *chain_id* holding copies of *residues*.

    Residues are copied so the parsed structures stay untouched. Adding the
    originals would re-parent them, and the later superposition and B-factor
    edits would then leak back into the input structures.
    """
    truncated = PDB.Chain.Chain(chain_id)
    for residue in residues:
        truncated.add(residue.copy())
    return truncated


def _ca_by_position(chain):
    """Map each residue's position within *chain* to its CA atom, skipping residues lacking one."""
    return {pos: residue["CA"] for pos, residue in enumerate(chain) if "CA" in residue}


def _output_stems(records):
    """Return one output file-name stem per chain record, unique per chain."""
    chains_per_file = defaultdict(int)
    for _, pdb_file, _ in records:
        chains_per_file[pdb_file] += 1
    stems = []
    for chain, pdb_file, model_id in records:
        stem = os.path.basename(pdb_file)
        if chains_per_file[pdb_file] > 1:
            # Single-chain inputs keep the plain "<file>_..." name; only files
            # with several protein chains need disambiguation.
            stem = f"{stem}_model{model_id}_chain{chain.id}"
        stems.append(stem)
    return stems


def _mean_ca_bfactors(chains):
    """Return the mean CA B-factor at each residue position across *chains*."""
    per_position = defaultdict(list)
    for chain in chains:
        for pos, atom in _ca_by_position(chain).items():
            per_position[pos].append(atom.get_bfactor())
    return {pos: float(np.mean(values)) for pos, values in per_position.items()}


def _pairwise_displacements(chains, stems, rmsd_path):
    """Superimpose every pair of chains and collect per-position CA distances.

    Each pair is fitted on the positions where both chains have a CA atom. The
    second chain of the pair is transformed in place. Pairwise RMSDs are written
    to *rmsd_path*, replacing any earlier contents.

    Returns:
        dict mapping residue position -> list of CA distances (one per fitted pair).
    """
    ca_maps = [_ca_by_position(chain) for chain in chains]
    displacements = defaultdict(list)
    superimposer = Superimposer()
    with open(rmsd_path, "w") as rmsd_file:
        for i in range(len(chains) - 1):
            for j in range(i + 1, len(chains)):
                shared = sorted(ca_maps[i].keys() & ca_maps[j].keys())
                if not shared:
                    continue
                fixed_atoms = [ca_maps[i][pos] for pos in shared]
                moving_atoms = [ca_maps[j][pos] for pos in shared]
                superimposer.set_atoms(fixed_atoms, moving_atoms)
                superimposer.apply(list(chains[j].get_atoms()))
                rmsd_file.write(
                    f"RMSD between chain {i} ({stems[i]}) and chain {j} ({stems[j]}): "
                    f"{superimposer.rms:.4f}\n"
                )
                # moving_atoms now hold the superimposed coordinates.
                for pos, fixed, moving in zip(shared, fixed_atoms, moving_atoms):
                    displacements[pos].append(float(np.linalg.norm(fixed.coord - moving.coord)))
    return displacements


def _write_bfactor_pdbs(chains, stems, values, output_dir, suffix):
    """Save each chain as '<stem>_<suffix>.pdb' with CA B-factors taken from *values*.

    CA atoms at positions missing from *values* get 0.0. Non-CA atoms keep the
    B-factor they already carry.
    """
    for chain, stem in zip(chains, stems):
        for pos, residue in enumerate(chain):
            if "CA" in residue:
                residue["CA"].set_bfactor(float(values.get(pos, 0.0)))
        pdb_io = PDBIO()
        pdb_io.set_structure(chain)
        pdb_io.save(os.path.join(output_dir, f"{stem}_{suffix}.pdb"))


def _region_color(position, regions, default="b"):
    """Return the bar colour for residue *position*; later regions override earlier ones."""
    color = default
    for start, end, region_color in regions or ():
        if start <= position <= end:
            color = region_color
    return color


def _plot_normalized_displacement(normalized, regions, plot_path):
    """Draw a bar plot of normalized displacement per residue position and save it."""
    positions = sorted(normalized)
    heights = [normalized[pos] for pos in positions]
    colors = [_region_color(pos, regions) for pos in positions]

    fig, ax = plt.subplots(figsize=(10, 3))
    ax.bar(positions, heights, width=1, color=colors)
    ax.set_xlabel('Residue Index')
    ax.set_ylabel('Normalized C-alpha Displacement by Average B-factor')
    ax.set_title('C-alpha Displacement Normalized by B-factor (Bar Plot)')
    # A tight bounding box keeps the long y-label from being clipped.
    fig.savefig(plot_path, bbox_inches="tight")
    plt.close(fig)


def calculate_displacement_and_rmsd(pdb_files, output_dir, regions=None):
    """Compare C-alpha positions across structures and write displacement outputs.

    Every chain containing at least one standard amino acid is treated as a
    protein chain. Each chain is reduced to its amino-acid residues and truncated
    to the length of the shortest chain. Residues are therefore compared by
    position, which assumes the inputs share one residue ordering (e.g.
    pre-aligned models of the same protein). Every pair of chains is superimposed
    on their shared CA positions, and the per-position CA distances are averaged.

    Files written to *output_dir*:
        rmsd.txt                            pairwise RMSD after superposition
                                            (overwritten on each run)
        average_displacement.txt            mean displacement per residue (1-based)
        <input>_displacement.pdb            CA B-factor = mean displacement
        <input>_normalized_displacement.pdb CA B-factor = mean displacement divided
                                            by the mean original CA B-factor
        normalized_displacement_plot.png    bar plot of the normalized values

    When one input file contains several protein chains, ``_model<m>_chain<c>``
    is appended to ``<input>`` so the chains do not overwrite each other.

    Args:
        pdb_files: paths of the input PDB files.
        output_dir: output directory; created if it does not exist.
        regions: optional iterable of ``(start, end, color)`` tuples giving
            0-based, inclusive residue-index ranges to colour in the plot. Later
            entries override earlier ones; other bars are blue.

    Raises:
        ValueError: if no protein chain is found in *pdb_files*.
    """
    os.makedirs(output_dir, exist_ok=True)

    parser = PDBParser(QUIET=True)
    structures = [parser.get_structure(os.path.basename(f), f) for f in pdb_files]

    records = _extract_protein_chains(structures, pdb_files)
    if not records:
        raise ValueError("No protein chains found in the supplied PDB files")

    # Compare residues by position: keep only amino acids, so waters and ligands
    # never shift the index, and cut every chain to the shortest length.
    residue_lists = [_amino_acid_residues(chain) for chain, _, _ in records]
    min_length = min(len(residues) for residues in residue_lists)
    chains = [
        _truncated_copy(chain.id, residues[:min_length])
        for (chain, _, _), residues in zip(records, residue_lists)
    ]
    stems = _output_stems(records)

    # Snapshot the input B-factors before any output step overwrites them.
    mean_bfactors = _mean_ca_bfactors(chains)

    displacements = _pairwise_displacements(chains, stems, os.path.join(output_dir, 'rmsd.txt'))
    average_displacements = {
        pos: float(np.mean(values)) for pos, values in sorted(displacements.items())
    }

    with open(os.path.join(output_dir, 'average_displacement.txt'), 'w') as avg_file:
        for pos, displacement in average_displacements.items():
            avg_file.write(f'Residue {pos + 1}: {displacement:.4f}\n')

    _write_bfactor_pdbs(chains, stems, average_displacements, output_dir, "displacement")

    # Positions with a zero mean B-factor cannot be normalized. Leave them out
    # (0.0 in the PDB, no bar in the plot) instead of producing inf.
    normalized_displacement = {
        pos: displacement / mean_bfactors[pos]
        for pos, displacement in average_displacements.items()
        if mean_bfactors.get(pos)
    }
    _write_bfactor_pdbs(
        chains, stems, normalized_displacement, output_dir, "normalized_displacement"
    )

    _plot_normalized_displacement(
        normalized_displacement,
        regions,
        os.path.join(output_dir, 'normalized_displacement_plot.png'),
    )
    
# Example usage: replace these paths with your own aligned structures.
pdb_files = ['299_aligned_apo.pdb', '305_aligned_apo.pdb', '315_aligned_apo.pdb', '365_aligned_apo.pdb']
output_directory = 'protein_analysis_results'

# Regions to colour in the bar plot, as (start index, end index, color) with
# 0-based, inclusive residue indices; later entries override earlier ones.
regions = [(0, 96, 'steelblue'), (97, 260, 'deeppink'), (261, 434, 'goldenrod'), (435, 493, 'deeppink'),
           (494, 572, 'goldenrod'), (573, 613, 'deeppink'), (614, 694, 'cyan'), (695, 865, 'darkviolet'),
           (866, 1029, 'forestgreen'), (1030, 1088, 'darkviolet'), (1089, 1169, 'forestgreen'), (1170, 1200, 'darkviolet')]

# Run the analysis only when executed as a script, so importing this module
# (e.g. to reuse calculate_displacement_and_rmsd) has no side effects.
if __name__ == "__main__":
    calculate_displacement_and_rmsd(pdb_files, output_directory, regions)
